Variational Inference with Implicit Approximate Inference Models (WIP Pt. 2)
# Notebook setup: render matplotlib figures inline as SVG.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import numpy as np
import keras.backend as K
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit
from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from IPython.display import SVG
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
# display animation inline
plt.rc('animation', html='html5')
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
# Sanity check: report the TensorFlow version behind the Keras backend.
K.tf.__version__
# Problem sizes and batch-size hyper-parameters.
LATENT_DIM = 2      # dimensionality of the weight (latent) space
NOISE_DIM = 3       # dimensionality of the noise fed to the inference model
BATCH_SIZE = 128
D_BATCH_SIZE = 128  # discriminator batch size
G_BATCH_SIZE = 128  # generator (inference model) batch size
PRIOR_VARIANCE = 2.
Bayesian Logistic Regression (Synthetic Data)
# Dense 300x300 grid over the 2-D weight space [-5, 5]^2, used for all
# contour plots below; w_grid has shape (300, 300, 2).
w_min, w_max = -5, 5
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
w_grid = np.dstack((w1, w2))
w_grid.shape
# Isotropic Gaussian prior p(w) = N(0, PRIOR_VARIANCE * I).
prior = multivariate_normal(mean=np.zeros(LATENT_DIM),
                            cov=PRIOR_VARIANCE)
log_prior = prior.logpdf(w_grid)
log_prior.shape
# Visualise the log-prior over the weight grid.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, log_prior, cmap='magma')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Tiny synthetic dataset: three 2-D inputs with binary labels.
x1 = np.array([ 1.5, 1.])
x2 = np.array([-1.5, 1.])
x3 = np.array([- .5, -1.])
X = np.vstack((x1, x2, x3))
X.shape
y1 = 1
y2 = 1
y3 = 0
y = np.stack((y1, y2, y3))
y.shape
def log_likelihood(w, x, y):
    """Log-likelihood log p(y | x, w) of the Bernoulli-logistic model.

    Equivalent to the negative binary cross entropy.  Computed as
    log sigmoid(s * w.T x) with s = +1 for y = 1 and s = -1 for y = 0.

    Uses -logaddexp(0, -z) instead of the original log(expit(z)):
    mathematically identical, but expit underflows to 0 for large
    negative z, turning the result into log(0) = -inf.

    All arguments broadcast; e.g. with w of shape (2, 300, 300) (the
    transposed weight grid) and x of shape (2, 3) the result has shape
    (300, 300, 3) -- one slice per datapoint.
    """
    # flip the sign of the logit for negative labels
    signed_logits = np.dot(w.T, x) * (-1)**(1 - y)
    # log expit(z) = -log(1 + exp(-z)) = -logaddexp(0, -z)
    return -np.logaddexp(0., -signed_logits)
# Per-datapoint log-likelihood over the whole weight grid;
# shape (300, 300, 3): one slice per (x_i, y_i) pair.
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
# One contour panel per datapoint.
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()
for i, ax in enumerate(axes):
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)
    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')
    # only label the y-axis on the leftmost panel
    if not i:
        ax.set_ylabel('$w_2$')
plt.show()
# Joint log-likelihood of all three datapoints (sum over the data axis).
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, np.sum(llhs, axis=2),
            cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Unnormalised posterior exp(log p(w) + sum_i log p(y_i | x_i, w)),
# with the three inputs overlaid as red points.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2,
            np.exp(log_prior+np.sum(llhs, axis=2)),
            cmap=plt.cm.magma)
ax.plot(*np.vstack((x1,x2,x3)).T, 'ro')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
Model Definitions
Density Ratio Estimator (Discriminator) Model
In general the discriminator takes the form $T_{\psi}(x, z)$. Here we consider the simpler $T_{\psi}(w)$, with $T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$.
# Discriminator T_psi: R^2 -> R, a small MLP on weight samples w,
# trained with binary cross entropy to tell prior from posterior samples.
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
# Shares the discriminator's weights but exposes the pre-sigmoid 'logit'
# layer output (presumably used as the log density-ratio estimate --
# confirm against the derivation in the accompanying write-up).
ratio_estimator = Model(
    inputs=discriminator.inputs,
    outputs=discriminator.get_layer(name='logit').output)
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
# Evaluate the (untrained) ratio estimator over the weight grid.
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
The initial density-ratio estimate, before any training of the discriminator.
# Contour plot of the untrained ratio estimate over the weight grid.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Smoke test: evaluate loss/accuracy on 5 prior samples labelled class 0.
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
Approximate Inference Model
In general the inference model takes the form $z_{\phi}(x, \epsilon)$. Here we only consider $z_{\phi}(\epsilon)$, with $z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$.
# Implicit approximate inference model z_phi: noise in R^3 -> weights in R^2.
# No output activation: weight samples are unconstrained.
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
The variational parameters $\phi$ are the trainable weights of the approximate inference model
# The variational parameters phi are the network's trainable weights.
phi = inference.trainable_weights
phi
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
# Sample from the approximate posterior by pushing standard-normal noise
# through the inference model; sample from the prior for comparison.
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_posterior_samples = inference.predict(eps)
w_posterior_samples.shape
w_prior_samples = prior.rvs(size=BATCH_SIZE)
w_prior_samples.shape
# Approximate-posterior samples on top of the unnormalised posterior.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2,
            np.exp(log_prior+np.sum(llhs, axis=2)),
            cmap=plt.cm.magma)
ax.scatter(*w_posterior_samples.T, alpha=.6)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Figure reused by the training animation below: ratio contours plus
# posterior/prior sample scatters and a step-counter text box.
fig, ax = plt.subplots(figsize=(5, 5))
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
# NOTE(review): `animate=True` is not a documented contourf keyword in
# current matplotlib -- verify it is accepted by the version pinned here.
ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma, animate=True)
scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
props = dict(boxstyle='round', facecolor='w', alpha=0.5)
t = ax.text(0.05, 0.85, 'step: 0',
            transform=ax.transAxes, bbox=props)
plt.show()
Discriminator pre-training
def prior_samples_gen(batch_size):
    """Endlessly yield batches of `batch_size` samples from the prior."""
    while True:
        batch = prior.rvs(size=batch_size)
        yield batch
def posterior_samples_gen(inference_model, batch_size):
    """Endlessly yield batches of approximate-posterior samples.

    Each batch is produced by pushing fresh standard-normal noise of
    shape (batch_size, NOISE_DIM) through the implicit inference model.
    """
    while True:
        noise = np.random.randn(batch_size, NOISE_DIM)
        yield inference_model.predict(noise)
def discriminator_data_gen(inference_model, batch_size):
    """Endlessly yield labelled (inputs, targets) batches for the discriminator.

    Each batch stacks `batch_size` prior samples (target 0) on top of
    `batch_size` approximate-posterior samples (target 1).
    """
    priors = prior_samples_gen(batch_size)
    posteriors = posterior_samples_gen(inference_model, batch_size)
    while True:
        samples_prior = next(priors)
        samples_posterior = next(posteriors)
        inputs = np.vstack((samples_prior, samples_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
        yield inputs, targets
# Pre-train the discriminator on generated prior/posterior batches and
# report the final training loss.
h = discriminator.fit_generator(generator=discriminator_data_gen(inference, 128), steps_per_epoch=32, epochs=2)
h.history['loss'][-1]
# Bug fix: D_input/D_labels were referenced without ever being defined
# (NameError).  Draw one labelled batch from the data generator, then take
# a single gradient step; `metrics` holds [loss, binary_accuracy] as
# declared in discriminator.compile.
D_input, D_labels = next(discriminator_data_gen(inference, D_BATCH_SIZE))
metrics = discriminator.train_on_batch(D_input, D_labels)
def animate(step):
    """Redraw one animation frame: ratio contours, sample scatters, metrics.

    `step` is the frame index supplied by FuncAnimation.  Reads the
    module-level figure axes `ax`, the sample arrays and `metrics`.
    Returns the redrawn axes (blit=False, so the return value is unused).
    """
    ax.cla()
    # re-evaluate the density-ratio (logit) estimate over the weight grid
    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)
    ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
    info_dict = dict(zip(discriminator.metrics_names, metrics))
    info_dict['step'] = step
    # Bug fix: info_dict was built but never shown -- the text box always
    # displayed the hard-coded 'step: 0'.  Render the step and metrics.
    # NOTE(review): `metrics` comes from the single train_on_batch call
    # above and is not updated per frame -- presumably a training step
    # inside animate was intended; confirm.
    label = '\n'.join('{}: {:.3g}'.format(k, v)
                      for k, v in sorted(info_dict.items()))
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)
    t = ax.text(0.05, 0.85, label,
                transform=ax.transAxes, bbox=props)
    scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
    scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)
    return ax
# Build the animation: 50 frames, 200 ms apart (5 fps).  blit=False lets
# `animate` clear and redraw the whole axes each frame.
FuncAnimation(fig, animate, frames=50,
              interval=200, # 5 fps
              blit=False)
# --- Scratch/exploration: inspecting contourf's artist collections and
# --- prototyping the animation machinery on random data.
fig, ax = plt.subplots(figsize=(7, 7))
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
# NOTE(review): `animate=True` is not a documented contourf keyword in
# current matplotlib -- verify against the pinned version.
cset = ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma, animate=True)
scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
props = dict(boxstyle='round', facecolor='w', alpha=0.5)
t = ax.text(0.05, 0.85, 'step: 0',
            transform=ax.transAxes, bbox=props)
plt.show()
# Inspect the artists produced by contourf.
cset.collections
from matplotlib.collections import PatchCollection
dir(cset.collections[0])
import matplotlib.patches as patches
from matplotlib.collections import PathCollection
# Re-draw a single contour level's collection on a fresh axes.
fig, ax = plt.subplots(figsize=(7, 7))
ax.add_collection(cset.collections[4])
plt.show()
# Minimal animation prototype: random contour field plus random scatter.
fig, ax = plt.subplots(figsize=(7, 7))
cset = ax.contourf(np.linspace(-3, 3, 32), np.linspace(-3, 3, 32), np.random.randn(32, 32), cmap='magma')
scat = ax.scatter(*np.random.randn(2, 128), alpha=.8)
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
plt.show()
def animate(step):
    """Prototype frame function: redraw random contours and scatter.

    Reads the module-level `ax`; `step` (the frame index) is unused
    because every frame draws fresh random data.
    """
    ax.cla()
    ax.contourf(np.linspace(-3, 3, 32), np.linspace(-3, 3, 32), np.random.randn(32, 32), cmap='magma')
    # Bug fix: previously returned the stale module-level `scat` that
    # ax.cla() had just removed from the axes; return the freshly drawn
    # artist so the sequence of artists would be correct under blitting.
    new_scat = ax.scatter(*np.random.randn(2, 128), alpha=.8)
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    return new_scat
# Run the prototype animation: 25 frames, 200 ms apart (5 fps).
FuncAnimation(fig, animate, frames=25,
              interval=200) # 5 fps
